{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "89a363cd-8944-475a-88ea-4401785218c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b45cbeda-65bc-4b10-9f9c-9afd43a54d20",
   "metadata": {},
   "source": [
    "# Channel Selection in Multivariate Time Series Classification \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08decb5b-8dfb-4666-a66b-3a2349960956",
   "metadata": {},
   "source": [
    "## Overview"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da743484-17d3-4cec-8eaa-8d8293ce6f35",
   "metadata": {},
   "source": [
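    "Not every channel of a multivariate time series is needed for classification; often only a few carry useful information. Reference [1] proposes two fast channel selection techniques for multivariate time series classification (MTSC), Elbow Class Sum (ECS) and Elbow Class Pairwise (ECP); sktime implements both in `sktime.transformations.panel.channel_selection`. This notebook places one of them in front of a ROCKET classifier on the BasicMotions dataset."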
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcbe2174-a691-4093-ab80-d796edb5121d",
   "metadata": {},
   "source": [
    "[1] Fast Channel Selection for Scalable Multivariate Time Series Classification. [Link](https://www.researchgate.net/publication/354445008_Fast_Channel_Selection_for_Scalable_Multivariate_Time_Series_Classification)"
   ]
  },
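  {
   "cell_type": "markdown",
   "id": "ecp-sketch-md",
   "metadata": {},
   "source": [
    "The cell below is a minimal NumPy sketch of the idea behind Elbow Class Pairwise (ECP), the selector used in this notebook. It is a simplified, hypothetical re-implementation for illustration, not the code from [1] or from sktime, and it assumes that:\n",
    "\n",
    "- each class is summarised by its mean series (its centroid);\n",
    "- channels are compared by the Euclidean distance between two class centroids;\n",
    "- the elbow is the point of the descending distance curve farthest from the straight line joining its first and last points, and only channels strictly above it are kept.\n",
    "\n",
    "The rule is applied separately to every pair of classes, and the selected channels are the union over all pairs. The helper names `knee_index`, `select_above_knee` and `elbow_class_pairwise` are made up for this example. On the toy data only channel 0 separates the two classes, so it should be the only channel returned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecp-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import combinations\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def knee_index(values):\n",
    "    # `values` must be sorted in descending order. Returns the index of the\n",
    "    # point farthest from the straight line joining the first and last points.\n",
    "    y = np.asarray(values, dtype=float)\n",
    "    n = len(y)\n",
    "    if n < 3:\n",
    "        return 0\n",
    "    points = np.column_stack([np.arange(n, dtype=float), y])\n",
    "    chord = points[-1] - points[0]\n",
    "    length = np.linalg.norm(chord)\n",
    "    if length == 0:\n",
    "        return 0\n",
    "    chord = chord / length\n",
    "    rel = points - points[0]\n",
    "    # Remove the component along the chord; the rest is perpendicular to it.\n",
    "    perp = rel - np.outer(rel @ chord, chord)\n",
    "    return int(np.argmax(np.linalg.norm(perp, axis=1)))\n",
    "\n",
    "\n",
    "def select_above_knee(distances):\n",
    "    # Rank channels by distance and keep those strictly above the elbow value;\n",
    "    # fall back to the elbow channel itself if nothing lies above it.\n",
    "    distances = np.asarray(distances, dtype=float)\n",
    "    order = np.argsort(distances)[::-1]\n",
    "    ranked = distances[order]\n",
    "    k = knee_index(ranked)\n",
    "    chosen = [int(c) for c, v in zip(order, ranked) if v > ranked[k]]\n",
    "    return chosen or [int(order[k])]\n",
    "\n",
    "\n",
    "def elbow_class_pairwise(X, y):\n",
    "    # X: array of shape (n_instances, n_channels, n_timepoints); y: class labels.\n",
    "    y = np.asarray(y)\n",
    "    classes = np.unique(y)\n",
    "    # Class prototype (assumption): the mean series of each class, per channel.\n",
    "    centroids = {c: X[y == c].mean(axis=0) for c in classes}\n",
    "    selected = set()\n",
    "    for a, b in combinations(classes, 2):\n",
    "        # Euclidean distance between the two class centroids, one value per channel.\n",
    "        dist = np.linalg.norm(centroids[a] - centroids[b], axis=1)\n",
    "        selected.update(select_above_knee(dist))\n",
    "    return sorted(selected)\n",
    "\n",
    "\n",
    "# Toy data: 2 classes x 10 instances, 4 channels, 10 time points.\n",
    "# Only channel 0 separates the classes clearly; channels 1 and 2 differ slightly.\n",
    "X = np.zeros((20, 4, 10))\n",
    "y = np.array([0] * 10 + [1] * 10)\n",
    "X[y == 1, 0, :] += 5.0\n",
    "X[y == 1, 1, :] += 0.1\n",
    "X[y == 1, 2, :] += 0.05\n",
    "print(elbow_class_pairwise(X, y))"
   ]
  },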
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d1779970-eefb-4577-9c4e-e0a19ceadcc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import RidgeClassifierCV\n",
    "from sklearn.pipeline import make_pipeline\n",
    "\n",
    "from sktime.datasets import load_UCR_UEA_dataset\n",
    "from sktime.transformations.panel import channel_selection\n",
    "from sktime.transformations.panel.rocket import Rocket"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0437ca7a-5b5a-4e28-b565-0b2df4eac60d",
   "metadata": {},
   "source": [
    "## 1. Initialise the Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "830137a3-10c3-49b9-9a98-7062dc7ab1d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose one of the two selectors from [1]:\n",
    "# cs = channel_selection.ElbowClassSum()  # ECS: Elbow Class Sum\n",
    "cs = channel_selection.ElbowClassPairwise()  # ECP: Elbow Class Pairwise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "89443793-7cf0-4a4c-a4b1-d928a46a1bb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select channels, extract ROCKET features from them, then classify with a ridge classifier.\n",
    "rocket_pipeline = make_pipeline(cs, Rocket(), RidgeClassifierCV())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a268cc1-d5bf-4b02-916b-f3417c1cd3ff",
   "metadata": {},
   "source": [
    "## 2. Load the Data and Fit the Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "68f508e3-ecc9-4b3a-b4de-7073cd1dfb90",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = \"BasicMotions\"\n",
    "X_train, y_train = load_UCR_UEA_dataset(data, split=\"train\", return_X_y=True)\n",
    "X_test, y_test = load_UCR_UEA_dataset(data, split=\"test\", return_X_y=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "94f421bc-e384-4b98-89af-111a4d8c378b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pipeline(steps=[('elbowclasspairwise', ElbowClassPairwise()),\n",
       "                ('rocket', Rocket()),\n",
       "                ('ridgeclassifiercv',\n",
       "                 RidgeClassifierCV(alphas=array([ 0.1,  1. , 10. ])))])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rocket_pipeline.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c867fdb-5126-44f6-b7f0-f999d3f60457",
   "metadata": {},
   "source": [
    "## 3. Classify the Test Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "04573f4d-0b61-4ab8-8355-79f0aa1ca04f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean accuracy of the fitted pipeline on the test set\n",
    "rocket_pipeline.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d18ac8bc-a83a-4dd7-b577-aefc25d7bed6",
   "metadata": {},
   "source": [
    "## 4. Identify the Selected Channels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "35a44d68-7bce-44b0-baf3-e4f11606001c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 1]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Indices of the channels kept by the fitted selector (the first pipeline step)\n",
    "rocket_pipeline.steps[0][1].channels_selected_"
   ]
  },
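  {
   "cell_type": "markdown",
   "id": "step-access-md",
   "metadata": {},
   "source": [
    "`steps[0][1]` takes the estimator out of the first `(name, estimator)` tuple in the pipeline. As the `fit` output above shows, `make_pipeline` names each step after its lowercased class name, so the same selector should also be reachable as `rocket_pipeline.named_steps[\"elbowclasspairwise\"]` or `rocket_pipeline[0]`. The cell below checks this equivalence on a small scikit-learn pipeline that does not need sktime; `StandardScaler` and `LogisticRegression` are stand-ins chosen only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "step-access-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# make_pipeline names each step after its lowercased class name.\n",
    "pipe = make_pipeline(StandardScaler(), LogisticRegression())\n",
    "step_names = [name for name, _ in pipe.steps]\n",
    "print(step_names)\n",
    "\n",
    "# Three equivalent ways to reach the first step:\n",
    "first_by_tuple = pipe.steps[0][1]  # second item of the (name, estimator) tuple\n",
    "first_by_name = pipe.named_steps[\"standardscaler\"]  # look-up by step name\n",
    "first_by_index = pipe[0]  # Pipeline supports integer indexing\n",
    "print(first_by_tuple is first_by_name is first_by_index)"
   ]
  },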
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "358ab28f-edbe-49f4-95f5-8a7d0fb5d166",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Centroid_badminton_running</th>\n",
       "      <th>Centroid_badminton_standing</th>\n",
       "      <th>Centroid_badminton_walking</th>\n",
       "      <th>Centroid_running_standing</th>\n",
       "      <th>Centroid_running_walking</th>\n",
       "      <th>Centroid_standing_walking</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39.594679</td>\n",
       "      <td>55.752785</td>\n",
       "      <td>48.440779</td>\n",
       "      <td>63.610220</td>\n",
       "      <td>57.247383</td>\n",
       "      <td>10.717044</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>57.681767</td>\n",
       "      <td>24.390543</td>\n",
       "      <td>27.770269</td>\n",
       "      <td>60.458125</td>\n",
       "      <td>62.339120</td>\n",
       "      <td>16.370347</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>20.175911</td>\n",
       "      <td>24.126969</td>\n",
       "      <td>22.331621</td>\n",
       "      <td>25.671979</td>\n",
       "      <td>22.991555</td>\n",
       "      <td>4.897452</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>12.546212</td>\n",
       "      <td>12.439152</td>\n",
       "      <td>12.741854</td>\n",
       "      <td>6.317654</td>\n",
       "      <td>6.695743</td>\n",
       "      <td>3.585273</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>10.101196</td>\n",
       "      <td>8.865871</td>\n",
       "      <td>9.221908</td>\n",
       "      <td>6.520172</td>\n",
       "      <td>6.715702</td>\n",
       "      <td>1.299989</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>23.464251</td>\n",
       "      <td>14.568685</td>\n",
       "      <td>13.953445</td>\n",
       "      <td>18.878429</td>\n",
       "      <td>19.768549</td>\n",
       "      <td>7.228389</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Centroid_badminton_running  Centroid_badminton_standing  \\\n",
       "0                   39.594679                    55.752785   \n",
       "1                   57.681767                    24.390543   \n",
       "2                   20.175911                    24.126969   \n",
       "3                   12.546212                    12.439152   \n",
       "4                   10.101196                     8.865871   \n",
       "5                   23.464251                    14.568685   \n",
       "\n",
       "   Centroid_badminton_walking  Centroid_running_standing  \\\n",
       "0                   48.440779                  63.610220   \n",
       "1                   27.770269                  60.458125   \n",
       "2                   22.331621                  25.671979   \n",
       "3                   12.741854                   6.317654   \n",
       "4                    9.221908                   6.520172   \n",
       "5                   13.953445                  18.878429   \n",
       "\n",
       "   Centroid_running_walking  Centroid_standing_walking  \n",
       "0                 57.247383                  10.717044  \n",
       "1                 62.339120                  16.370347  \n",
       "2                 22.991555                   4.897452  \n",
       "3                  6.695743                   3.585273  \n",
       "4                  6.715702                   1.299989  \n",
       "5                 19.768549                   7.228389  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distances between class centroids: one row per channel, one column per class pair\n",
    "rocket_pipeline.steps[0][1].distance_frame_"
   ]
  },
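  {
   "cell_type": "markdown",
   "id": "distance-check-md",
   "metadata": {},
   "source": [
    "As a rough consistency check, the cell below applies a simple elbow rule directly to the distance table printed above. The rule is an assumption for illustration, not necessarily what sktime does: channels are ranked by distance, the elbow is the ranked point farthest from the line joining the first and last ranked values, and only channels strictly above the elbow are kept. The ECP-style variant applies the rule to each class-pair column and takes the union; the ECS-style variant first sums each channel's distances over all pairs and applies the rule once. The helpers `knee_index` and `select_above_knee` are hypothetical.\n",
    "\n",
    "Under these assumptions both variants keep channels 0 and 1, which agrees with `channels_selected_` above. The agreement does not prove that sktime uses exactly this rule."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "distance-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Distance table copied from the output above: rows are channels 0-5,\n",
    "# columns are the class pairs in the same order as distance_frame_.\n",
    "distances = np.array(\n",
    "    [\n",
    "        [39.594679, 55.752785, 48.440779, 63.610220, 57.247383, 10.717044],\n",
    "        [57.681767, 24.390543, 27.770269, 60.458125, 62.339120, 16.370347],\n",
    "        [20.175911, 24.126969, 22.331621, 25.671979, 22.991555, 4.897452],\n",
    "        [12.546212, 12.439152, 12.741854, 6.317654, 6.695743, 3.585273],\n",
    "        [10.101196, 8.865871, 9.221908, 6.520172, 6.715702, 1.299989],\n",
    "        [23.464251, 14.568685, 13.953445, 18.878429, 19.768549, 7.228389],\n",
    "    ]\n",
    ")\n",
    "\n",
    "\n",
    "def knee_index(values):\n",
    "    # `values` sorted in descending order; returns the index of the point\n",
    "    # farthest from the line joining the first and last points.\n",
    "    y = np.asarray(values, dtype=float)\n",
    "    n = len(y)\n",
    "    if n < 3:\n",
    "        return 0\n",
    "    points = np.column_stack([np.arange(n, dtype=float), y])\n",
    "    chord = points[-1] - points[0]\n",
    "    length = np.linalg.norm(chord)\n",
    "    if length == 0:\n",
    "        return 0\n",
    "    chord = chord / length\n",
    "    rel = points - points[0]\n",
    "    # Perpendicular component of each point relative to the chord.\n",
    "    perp = rel - np.outer(rel @ chord, chord)\n",
    "    return int(np.argmax(np.linalg.norm(perp, axis=1)))\n",
    "\n",
    "\n",
    "def select_above_knee(scores):\n",
    "    # Keep channels whose score is strictly above the elbow value;\n",
    "    # fall back to the elbow channel itself if nothing lies above it.\n",
    "    scores = np.asarray(scores, dtype=float)\n",
    "    order = np.argsort(scores)[::-1]\n",
    "    ranked = scores[order]\n",
    "    k = knee_index(ranked)\n",
    "    chosen = [int(c) for c, v in zip(order, ranked) if v > ranked[k]]\n",
    "    return chosen or [int(order[k])]\n",
    "\n",
    "\n",
    "# ECP-style: one elbow per class pair (column), then the union of kept channels.\n",
    "ecp = sorted({c for column in distances.T for c in select_above_knee(column)})\n",
    "# ECS-style: sum each channel's distances over all pairs, then a single elbow.\n",
    "ecs = sorted(select_above_knee(distances.sum(axis=1)))\n",
    "print(ecp, ecs)"
   ]
  },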
  {
   "cell_type": "markdown",
   "id": "c75f99ea-3966-483b-9f08-89e3f0dbffeb",
   "metadata": {},
   "source": [
    "## 5. Use the Channel Selector Standalone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "82607728-1095-4f15-a06e-d2463bb5c642",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ElbowClassPairwise()"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The selector can also be fitted on its own, outside a pipeline\n",
    "cs.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1f3ec36-ce86-4388-b7b5-b3b2a087b43a",
   "metadata": {},
   "source": [
    "## 6. Inspect the Distance Matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f4a19774-368e-43d4-a45a-d8109ae2d17f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Centroid_badminton_running</th>\n",
       "      <th>Centroid_badminton_standing</th>\n",
       "      <th>Centroid_badminton_walking</th>\n",
       "      <th>Centroid_running_standing</th>\n",
       "      <th>Centroid_running_walking</th>\n",
       "      <th>Centroid_standing_walking</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39.594679</td>\n",
       "      <td>55.752785</td>\n",
       "      <td>48.440779</td>\n",
       "      <td>63.610220</td>\n",
       "      <td>57.247383</td>\n",
       "      <td>10.717044</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>57.681767</td>\n",
       "      <td>24.390543</td>\n",
       "      <td>27.770269</td>\n",
       "      <td>60.458125</td>\n",
       "      <td>62.339120</td>\n",
       "      <td>16.370347</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>20.175911</td>\n",
       "      <td>24.126969</td>\n",
       "      <td>22.331621</td>\n",
       "      <td>25.671979</td>\n",
       "      <td>22.991555</td>\n",
       "      <td>4.897452</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>12.546212</td>\n",
       "      <td>12.439152</td>\n",
       "      <td>12.741854</td>\n",
       "      <td>6.317654</td>\n",
       "      <td>6.695743</td>\n",
       "      <td>3.585273</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>10.101196</td>\n",
       "      <td>8.865871</td>\n",
       "      <td>9.221908</td>\n",
       "      <td>6.520172</td>\n",
       "      <td>6.715702</td>\n",
       "      <td>1.299989</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>23.464251</td>\n",
       "      <td>14.568685</td>\n",
       "      <td>13.953445</td>\n",
       "      <td>18.878429</td>\n",
       "      <td>19.768549</td>\n",
       "      <td>7.228389</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Centroid_badminton_running  Centroid_badminton_standing  \\\n",
       "0                   39.594679                    55.752785   \n",
       "1                   57.681767                    24.390543   \n",
       "2                   20.175911                    24.126969   \n",
       "3                   12.546212                    12.439152   \n",
       "4                   10.101196                     8.865871   \n",
       "5                   23.464251                    14.568685   \n",
       "\n",
       "   Centroid_badminton_walking  Centroid_running_standing  \\\n",
       "0                   48.440779                  63.610220   \n",
       "1                   27.770269                  60.458125   \n",
       "2                   22.331621                  25.671979   \n",
       "3                   12.741854                   6.317654   \n",
       "4                    9.221908                   6.520172   \n",
       "5                   13.953445                  18.878429   \n",
       "\n",
       "   Centroid_running_walking  Centroid_standing_walking  \n",
       "0                 57.247383                  10.717044  \n",
       "1                 62.339120                  16.370347  \n",
       "2                 22.991555                   4.897452  \n",
       "3                  6.695743                   3.585273  \n",
       "4                  6.715702                   1.299989  \n",
       "5                 19.768549                   7.228389  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cs.distance_frame_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a29b0ece",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cs.train_time_"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "30ff7f6bb2505d289b6e6022e217e794dc64e9153f959b8a264cb3c597a35999"
  },
  "kernelspec": {
   "display_name": "Python 3.7.5 ('sktime-test')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
